import bz2
import sys
import os
import string
import subprocess
import pybedtools
from Bio import Entrez
from Bio.Seq import reverse_complement


# RefSeq chromosome accessions (as they appear, without version, in
# 'Gene-commentary_accession') mapped to UCSC-style chromosome names.
# Insertion order -- chr1..chr22, chrX, chrY, chrM -- is significant: it is
# also the order in which output intervals are sorted.
chromosomes = {'NC_%06d' % number: 'chr%d' % number for number in range(1, 23)}
chromosomes['NC_000023'] = 'chrX'
chromosomes['NC_000024'] = 'chrY'
chromosomes['NC_012920'] = 'chrM'

class Transcript:
    """One genomic placement of a gene: chromosome, strand and exon ranges.

    Every attribute starts empty and is filled in by the code that builds
    the transcript.
    """

    def __init__(self):
        self.chromosome = None
        self.strand = None
        # List of [start, end] pairs, appended to by the builder.
        self.exons = []

    def __str__(self):
        # "!s" applies str() to each field, exactly like "%s" would.
        return "{!s}: {!s}, {!s}".format(self.chromosome, self.exons, self.strand)

class Transcripts(list):
    """A list of Transcript objects for one gene, plus gene-level metadata."""

    def __init__(self, iterable=()):
        # Delegate to list.__init__ so initial items (if any) are stored like
        # for any list; the default keeps the original no-argument behaviour.
        super().__init__(iterable)
        self.gene_name = None
        # NOTE(review): nothing in this script assigns gene_type; it stays None.
        self.gene_type = None

def _interval_range(interval):
    """Return ``[start, end]`` for an NCBI Seq-interval mapping.

    ``Seq-interval_to`` is incremented by one, i.e. treated as inclusive, so
    that ``end - start`` is the interval length.
    """
    return [int(interval['Seq-interval_from']),
            int(interval['Seq-interval_to']) + 1]


def _check_interval_strand(interval, strand):
    """Assert that *interval* lies on *strand* ('+' or '-').

    Any other *strand* value (an NCBI strand name that was not mapped to
    '+'/'-') is not checked, and the interval's strand is then not read.
    """
    expected = {'+': 'plus', '-': 'minus'}.get(strand)
    if expected is not None:
        actual = interval['Seq-interval_strand']['Na-strand'].attributes['value']
        assert actual == expected


def _is_annotated(record):
    """Return False if *record* says it is not annotated on the reference assembly."""
    for comment in record['Entrezgene_comments']:
        if comment.get('Gene-commentary_heading') != 'Annotation Information':
            continue
        for prop in comment['Gene-commentary_properties']:
            if prop['Gene-commentary_text'] == 'not annotated on reference assembly':
                return False
    return True


def find_transcript_locations(record, gene_id):
    """Return the genomic locations of the gene described by *record*.

    *record* is one parsed Entrez Gene record.  Locations come from
    ``Entrezgene_locus``:

    - entries on the GRCh38.p13 primary assembly, excluding unlocalized
      scaffolds;
    - entries without a heading, which must map to chrM.

    Alternate loci and patches are skipped.  For each kept location, every
    tRNA product (type "5") becomes one Transcript whose exons are that
    product's genomic coordinates.  A location without products becomes a
    single-exon Transcript spanning the whole gene.

    For multi-interval products on the minus strand the exon list is
    reversed, presumably so exons are in ascending order (read_record uses
    the first and last exon as the span).

    Args:
        record: parsed Entrez Gene record (dict-like, with Bio.Entrez
            elements carrying ``.attributes``).
        gene_id: accepted for the caller's convenience; not used here.

    Returns:
        Transcripts: ``gene_name`` is always set.  The list is empty for
        'NEWENTRY' genes and for genes not annotated on the reference
        assembly.

    Raises:
        Exception: unexpected assembly heading or genomic-coordinate form.
        KeyError: an accession that is missing from ``chromosomes``.
        AssertionError: the record deviates from the expected structure.
    """
    transcripts = Transcripts()
    gene_ref = record['Entrezgene_gene']['Gene-ref']
    # Fall back to the first synonym when there is no locus symbol.
    transcripts.gene_name = gene_ref.get('Gene-ref_locus') or gene_ref['Gene-ref_syn'][0]
    if transcripts.gene_name == 'NEWENTRY':
        return transcripts
    if not _is_annotated(record):
        return transcripts
    for gene_commentary in record['Entrezgene_locus']:
        label = gene_commentary.get('Gene-commentary_label')
        heading = gene_commentary.get("Gene-commentary_heading")
        accession = gene_commentary.get('Gene-commentary_accession')
        assert gene_commentary['Gene-commentary_type'] == '1'
        assert gene_commentary['Gene-commentary_type'].attributes['value'] == 'genomic'
        if accession is None:  # not annotated
            continue
        if heading is None:
            # Heading-less entries are expected to be chrM only.
            if gene_commentary.get('Gene-commentary_version') == '0':
                continue
            if label in ('RefSeqGene', 'genomic region'):
                continue
            chromosome = chromosomes[accession]
            assert chromosome == 'chrM'
        elif heading == 'Reference GRCh38.p13 Primary Assembly':
            # Unlocalized scaffolds are skipped; any other primary-assembly
            # entry must be a whole chromosome.
            if label.endswith("Unlocalized Scaffold Reference GRCh38.p13 Primary Assembly"):
                continue
            chromosome = chromosomes[accession]
            assert label == "Chromosome %s Reference GRCh38.p13 Primary Assembly" % chromosome[3:]
        elif heading in (
              'Reference GRCh38.p13 ALT_REF_LOCI_1',
              'Reference GRCh38.p13 ALT_REF_LOCI_2',
              'Reference GRCh38.p13 ALT_REF_LOCI_3',
              'Reference GRCh38.p13 ALT_REF_LOCI_4',
              'Reference GRCh38.p13 ALT_REF_LOCI_5',
              'Reference GRCh38.p13 ALT_REF_LOCI_6',
              'Reference GRCh38.p13 ALT_REF_LOCI_7',
              'Reference GRCh38.p13 PATCHES',
            ):
            continue
        else:
            raise Exception("Unexpected: '%s'" % heading)

        assert len(gene_commentary['Gene-commentary_seqs']) == 1
        interval = gene_commentary['Gene-commentary_seqs'][0]['Seq-loc_int']['Seq-interval']
        start, end = _interval_range(interval)
        na_strand = interval['Seq-interval_strand']['Na-strand'].attributes['value']
        # Map NCBI strand names to BED symbols; other values pass through.
        strand = {'plus': '+', 'minus': '-'}.get(na_strand, na_strand)

        products = gene_commentary.get('Gene-commentary_products')
        if not products:
            transcript = Transcript()
            transcript.chromosome = chromosome
            transcript.strand = strand
            transcript.exons = [[start, end]]
            transcripts.append(transcript)
            continue

        for product in products:
            if product['Gene-commentary_type'] != "5":
                continue
            assert product['Gene-commentary_type'].attributes['value'] == 'tRNA'
            transcript = Transcript()
            transcripts.append(transcript)
            transcript.chromosome = chromosome
            transcript.strand = strand
            genomic_coords = product['Gene-commentary_genomic-coords']
            assert len(genomic_coords) == 1
            coordinates = genomic_coords[0]
            assert len(coordinates) == 1
            if 'Seq-loc_int' in coordinates:
                pieces = [coordinates['Seq-loc_int']['Seq-interval']]
            elif 'Seq-loc_mix' in coordinates:
                pieces = [item['Seq-loc_int']['Seq-interval']
                          for item in coordinates['Seq-loc_mix']['Seq-loc-mix']]
            else:
                # Previously this left a transcript with no exons, which
                # crashed later with an opaque IndexError.
                raise Exception("Unexpected genomic coordinates: %s" % list(coordinates))
            for piece in pieces:
                _check_interval_strand(piece, strand)
                transcript.exons.append(_interval_range(piece))
            if strand == '-':
                transcript.exons.reverse()
    return transcripts

# Full amino-acid names, as they appear in tRNA gene names, mapped to the
# abbreviations used in the output.
_AMINOACID_ABBREVIATIONS = {
    'alanine': 'Ala',
    'amber suppressor': 'Amb',
    'arginine': 'Arg',
    'asparagine': 'Asn',
    'aspartate': 'Asp',
    'aspartic acid': 'Asp',
    'cysteine': 'Cys',
    'histidine': 'His',
    'isoleucine': 'Ile',
    'glutamine': 'Gln',
    'glutamate': 'Glu',
    'glutamic acid': 'Glu',
    'glycine': 'Gly',
    'leucine': 'Leu',
    'lysine': 'Lys',
    'initiator methionine': 'iMet',
    'methionine': 'Met',
    'ochre suppressor': 'Och',
    'opal suppressor': 'Opl',
    'phenylalanine': 'Phe',
    'proline': 'Pro',
    'selenocysteine': 'SeC',
    'serine': 'Ser',
    'threonine': 'Thr',
    'tryptophan': 'Trp',
    'tyrosine': 'Tyr',
    'valine': 'Val',
}
_KNOWN_ABBREVIATIONS = frozenset(_AMINOACID_ABBREVIATIONS.values())

# Terms that introduce the amino-acid part of a tRNA gene name, tried in
# this order.
_TRNA_NAME_TERMS = (
    "transfer RNA ",
    "transfer RNA-",
    "tRNA ",
    "tRNA-",
    "trRNA-",
    "nuclear encoded tRNA ",
    "mitochondrially encoded tRNA ",
    "nuclear-encoded mitochondrial transfer RNA-",
    "nuclear-encoded mitochondrial tRNA-",
)


def parse_codon_aminoacid(text):
    """Extract ``(codon, aminoacid)`` from a tRNA gene name or description.

    Examples of accepted shapes::

        "tRNA-Ala (AGC)"                        -> ("AGC", "Ala")
        "tRNA-Leu (UUA/G)"                      -> ("UUR", "Leu")
        "transfer RNA alanine (anticodon AGC)"  -> ("GCU", "Ala")
        "transfer RNA-lysine 2, mitochondrial"  -> ("Unk", "Lys")

    A term from ``_TRNA_NAME_TERMS`` is removed first.  It is removed as a
    prefix if one matches, otherwise at its first occurrence anywhere.

    A parenthesized value is one of:

    - a codon, used as written; the 5-character "XYA/G" and "XYU/C" forms
      are collapsed to IUPAC R and Y;
    - "anticodon NNN", which is reverse-complemented and written with U
      instead of T.

    Without parentheses the codon is ``'Unk'``.  Trailing copy numbers and
    an " initiator" suffix are removed from the amino-acid name.  A plain
    "suppressor" is resolved from its codon.

    Returns:
        tuple[str, str]: the codon and the amino-acid abbreviation.

    Raises:
        Exception: unknown description, wobble notation or suppressor codon.
        KeyError: the amino-acid name is not in ``_AMINOACID_ABBREVIATIONS``.
        AssertionError: a parenthesized value has an unexpected shape.
    """
    for prefix in _TRNA_NAME_TERMS:
        if text.startswith(prefix):
            text = text[len(prefix):]
            break
    else:
        # No known prefix: drop the first known term found anywhere.
        for term in _TRNA_NAME_TERMS:
            if term in text:
                text = text.replace(term, "")
                break
        else:
            raise Exception("Unknown description: '%s'" % text)

    if '(' in text:
        # Keep only what precedes the first ')', then split name from value.
        text = text[:text.index(')')]
        aminoacid, text = text.split('(')
        if ' ' in text:
            words = text.split()
            assert words[0] == 'anticodon'
            codon = reverse_complement(words[1]).replace('T', 'U')
        else:
            codon = text
        if len(codon) == 5:
            # e.g. "UUA/G": the third position is ambiguous.
            assert codon[3] == '/'
            wobble = codon[2:]
            if wobble == "A/G":
                nucleotide = 'R'
            elif wobble == "U/C":
                nucleotide = 'Y'
            else:
                raise Exception("Unknown wobble: '%s'" % wobble)
            codon = codon[:2] + nucleotide
        else:
            assert len(codon) == 3
    else:
        if ',' in text:
            aminoacid, text = text.split(",")
            assert text.strip() == 'mitochondrial'
        else:
            aminoacid = text
        codon = 'Unk'

    # Drop trailing copy numbers ("lysine 2" -> "lysine").  Unlike indexing
    # aminoacid[-1], rstrip also copes with an empty or all-digit name, which
    # the table lookup below then rejects with a KeyError.
    aminoacid = aminoacid.strip().rstrip(string.digits).strip()
    if aminoacid.endswith(" initiator"):
        aminoacid = aminoacid[:-len(" initiator")]
    if aminoacid == 'suppressor':
        # Suppressors are named by codon: UAA ochre, UAG amber, UGA opal.
        # An anticodon written in DNA letters is converted to that codon.
        if codon in ('TTA', 'CTA', 'TCA'):
            codon = reverse_complement(codon).replace('T', 'U')
        if codon == 'UAA':
            aminoacid = "ochre suppressor"
        elif codon == 'UAG':
            aminoacid = "amber suppressor"
        elif codon == 'UGA':
            aminoacid = "opal suppressor"
        else:
            raise Exception("Unknown suppressor codon: '%s'" % codon)
    if aminoacid not in _KNOWN_ABBREVIATIONS:
        aminoacid = _AMINOACID_ABBREVIATIONS[aminoacid]
    return (codon, aminoacid)

def find_codon_information(record):
    """Return the ``(codon, aminoacid)`` pair for the tRNA gene in *record*.

    Three kinds of text are parsed with ``parse_codon_aminoacid``:

    - every 'Official Full Name' property;
    - the gene description (``Gene-ref_desc``);
    - the formal nomenclature name.

    All parsed results must agree on a single codon and a single amino acid.

    Raises:
        Exception: "Unknown codon" if the record has none of these texts, or
            an "Inconsistent codon information" error if they disagree.
    """
    texts = []
    for gene_property in record.get('Entrezgene_properties', []):
        for commentary in gene_property.get('Gene-commentary_properties', []):
            if commentary.get('Gene-commentary_label') == 'Official Full Name':
                texts.append(commentary['Gene-commentary_text'])
    reference = record['Entrezgene_gene']['Gene-ref']
    description = reference.get('Gene-ref_desc')
    if description:
        texts.append(description)
    formal_name = reference.get('Gene-ref_formal-name')
    if formal_name:
        texts.append(formal_name['Gene-nomenclature']['Gene-nomenclature_name'])
    if not texts:
        raise Exception("Unknown codon")
    parsed = [parse_codon_aminoacid(text) for text in texts]
    codons = {codon for codon, _ in parsed}
    aminoacids = {aminoacid for _, aminoacid in parsed}
    # Explicit check instead of assert, so it still runs under "python -O".
    # NOTE(review): a text without a codon parses as 'Unk' and therefore
    # conflicts with a text that names one -- confirm that is intended.
    if len(codons) != 1 or len(aminoacids) != 1:
        raise Exception("Inconsistent codon information: codons %s, amino acids %s"
                        % (sorted(codons), sorted(aminoacids)))
    return (codons.pop(), aminoacids.pop())

def read_record(record, organism):
    if record['Entrezgene_source']['BioSource']['BioSource_org']['Org-ref']['Org-ref_taxname']!=organism:
        return
    gene_track = record['Entrezgene_track-info']['Gene-track']
    status = gene_track['Gene-track_status'].attributes['value']
    if status in ('discontinued', 'secondary'):
        return
    assert status=='live'
    gene_id = gene_track['Gene-track_geneid']
    gene_type = record['Entrezgene_type'].attributes['value']
    print("%s: %s" % (gene_id, gene_type))
    # sys.stderr.write("%s: %s\n" % (gene_id, gene_type))
    # sys.stderr.flush()
    description = record['Entrezgene_gene']['Gene-ref'].get('Gene-ref_desc', '')
    if gene_type!='tRNA':
        return
    transcripts = find_transcript_locations(record, gene_id)
    gene_name = transcripts.gene_name
    codon, aminoacid = find_codon_information(record)
    # A gene can be mapped to multiple locations. In particular, there are
    # genes that are mapped both to chrX and chrY.
    for transcript in transcripts:
        chromosome = transcript.chromosome
        strand = transcript.strand
        exons = transcript.exons
        blockCount = len(exons)
        blockStarts = []
        blockSizes = []
        chromStart = exons[0][0]
        chromEnd = exons[-1][1]
        for start, end in exons:
            assert start >= chromStart
            assert end <= chromEnd
            blockStart = start - chromStart
            blockSize = end - start
            blockStarts.append(str(blockStart))
            blockSizes.append(str(blockSize))
        name = "%s:%s:%s%s" % (gene_id, gene_name, aminoacid, codon)
        blockStarts = ",".join(blockStarts) + ","
        blockSizes = ",".join(blockSizes) + ","
        fields = [chromosome,
                  chromStart,
                  chromEnd,
                  name,
                  ".",
                  strand,
                  chromStart,
                  chromStart,
                  ".",
                  blockCount,
                  blockSizes,
                  blockStarts,
                 ]
        interval = pybedtools.create_interval_from_list(fields)
        yield interval

def write_transcript_file(organism):
    """Extract the tRNA genes of *organism* and write them to ``trnas.bed``.

    Steps:

    1. Read ``<organism with spaces replaced by underscores>.ags.gz`` from
       the current directory.  It is piped through NCBI's ``gene2xml``; the
       ``-b T -c T`` flags presumably select binary, compressed input --
       confirm against gene2xml's usage.
    2. Convert each record with ``read_record``.
    3. Sort the intervals by the order of ``chromosomes``, then by start.
    4. Write them after a track line stamped with the source file's
       modification date.

    Raises:
        subprocess.CalledProcessError: if ``gene2xml`` exits unsuccessfully.
    """
    source = "%s.ags.gz" % organism.replace(" ", "_")
    command = ["gene2xml", "-b", "T", "-c", "T", "-i", source]
    print("Reading", source, "via gene2xml")
    # A gene can be mapped to multiple locations, so every record may
    # contribute several intervals.
    intervals = []
    # Popen as a context manager closes the pipe and waits for the child on
    # exit, so the process is reaped and its exit status is available.
    with subprocess.Popen(command, stdout=subprocess.PIPE) as process:
        for record in Entrez.parse(process.stdout):
            intervals.extend(read_record(record, organism))
    if process.returncode != 0:
        raise subprocess.CalledProcessError(process.returncode, command)
    rank = {name: index for index, name in enumerate(chromosomes.values())}
    intervals.sort(key=lambda interval: (rank[interval.chrom], interval.start))
    download_date = timestamp(source)
    filename = "trnas.bed"
    print("Writing", filename)
    # The description attribute's closing quote was previously missing.
    track_line = ('track name="tRNAs" '
                  'description="RefSeq tRNAs extracted from %s downloaded from NCBI on %s"\n'
                  % (source, download_date))
    with open(filename, 'w') as output:
        output.write(track_line)
        for interval in intervals:
            output.write(str(interval))


def timestamp(filename):
    """Return the last-modification date of *filename* as ``YYYY.MM.DD``.

    The date is expressed in the local time zone.
    """
    import os
    import time

    modified = time.localtime(os.path.getmtime(filename))
    return time.strftime("%Y.%m.%d", modified)


# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    write_transcript_file('Homo sapiens')
